#include "limbo/ip/state.hpp"
#include "test-utils.h"

using namespace limbo;

/* Reference IPv4 packet (56 bytes) shared by the send and recv tests.
 *   byte  0      0x45       version 4, IHL 5 (20-byte header, no options)
 *   byte  1      0x00       TOS
 *   bytes 2..3   0x0038     total length = 56 = sizeof(example_raw)
 *   bytes 4..5   0x82DE     identification
 *   bytes 6..7   0x0000     flags / fragment offset
 *   byte  8      0x01       TTL
 *   byte  9      0x11       protocol (UDP)
 *   bytes 10..11 0x2C60     header checksum
 *   bytes 12..15            source 10.0.31.124
 *   bytes 16..19            destination 224.0.0.251
 *   bytes 20..55            UDP datagram, identical to `payload` below.
 * The UDP ports are 0x14E9 (5353), and the remaining bytes spell the name
 * "wpad.local", so it looks like a captured mDNS query. These IP-level
 * tests do not interpret it. */
uint8_t example_raw[]{
    0x45, 0x00, 0x00, 0x38, 0x82, 0xde, 0x00, 0x00, 0x01, 0x11, 0x2c, 0x60,
    0x0a, 0x00, 0x1f, 0x7c, 0xe0, 0x00, 0x00, 0xfb, 0x14, 0xe9, 0x14, 0xe9,
    0x00, 0x24, 0xb6, 0xaa, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x04, 0x77, 0x70, 0x61, 0x64, 0x05, 0x6c, 0x6f,
    0x63, 0x61, 0x6c, 0x00, 0x00, 0x01, 0x00, 0x01};
// Chunk spanning the whole reference packet.
auto example_chunk = Chunk(example_raw, sizeof(example_raw));

/* The 36 bytes that follow the 20-byte IPv4 header in example_raw (its UDP
 * datagram; UDP length field 0x0024 = 36). They are kept as a separate array
 * so tests can hand them to send() as the payload and compare them with the
 * payload parsed by recv(). */
uint8_t payload[] = {0x14, 0xe9, 0x14, 0xe9, 0x00, 0x24, 0xb6, 0xaa, 0x00,
                     0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
                     0x00, 0x00, 0x04, 0x77, 0x70, 0x61, 0x64, 0x05, 0x6c,
                     0x6f, 0x63, 0x61, 0x6c, 0x00, 0x00, 0x01, 0x00, 0x01};
auto payload_chunk = Chunk(payload, sizeof(payload));

auto src = make_address("10.0.31.124");
auto dst = make_address("224.0.0.251");

using State = ip::State<void, ip::Opts::calculate_checksum>;
using Context = typename State::Context;
using Packet = typename State::Packet;

/* Serialising a packet into a caller-provided buffer.
 *
 * Packet::send(...) writes header + payload; Packet::send<false>(...) writes
 * only the header. state.init(0x82DD) seeds the state; the expected bytes
 * (example_raw) carry id 0x82DE, so the first send presumably advances the
 * id before writing it. */
TEST_CASE("send", "[ip]") {
  // 20: the option-less IPv4 header, i.e. everything before the payload.
  constexpr auto header_sz = sizeof(example_raw) - sizeof(payload);

  char buff[sizeof(example_raw)];
  auto buff_chunk = Chunk(buff, sizeof(buff));
  auto ctx = Context{src, dst, 1, ip::Proto::udp, 0, 0, nullptr};
  auto state = State();
  state.init(0x82DD);

  SECTION("no enough space") {
    // A zero-length output buffer cannot hold even the header, so both
    // variants must fail without producing output.
    auto zero_chunk = Chunk(buff, 0);

    SECTION("with content") {
      auto result = Packet::send(state, ctx, zero_chunk, payload_chunk);
      REQUIRE(!result);
      CHECK(result.state() == &state);
      CHECK(result.error_code() ==
            static_cast<uint32_t>(Errc::no_enough_space));
    }
    SECTION("without content") {
      auto result = Packet::send<false>(state, ctx, zero_chunk, payload_chunk);
      REQUIRE(!result);
      CHECK(result.state() == &state);
      CHECK(result.error_code() ==
            static_cast<uint32_t>(Errc::no_enough_space));
    }
  }

  SECTION("with content") {
    auto result = Packet::send(state, ctx, buff_chunk, payload_chunk);
    REQUIRE(result);
    auto res_buff = result.consumed();
    // Output must start at the beginning of the caller's buffer.
    CHECK(static_cast<const void *>(res_buff.data()) ==
          static_cast<const void *>(buff));
    CHECK(res_buff.size() == example_chunk.size());
    CHECK(res_buff == example_chunk);
  }

  SECTION("without content") {
    auto result = Packet::send<false>(state, ctx, buff_chunk, payload_chunk);
    REQUIRE(result);
    auto res_buff = result.consumed();
    CHECK(static_cast<const void *>(res_buff.data()) ==
          static_cast<const void *>(buff));
    // Only the header is written; it matches the first bytes of example_raw.
    CHECK(res_buff.size() == header_sz);
    CHECK(res_buff == Chunk(example_raw, header_sz));
  }
}

/* Same scenario as "send", but with a State that omits
 * ip::Opts::calculate_checksum: the test expects the header checksum field
 * to be left as zero. */
TEST_CASE("send (no checksum)", "[ip]") {
  using State = ip::State<void>;
  using Context = typename State::Context;
  using Packet = typename State::Packet;

  // 20: the option-less IPv4 header, i.e. everything before the payload.
  constexpr auto header_sz = sizeof(example_raw) - sizeof(payload);

  // Expected output: example_raw with its header checksum (bytes 10..11)
  // zeroed. Named distinctly so the global example_chunk is not shadowed.
  uint8_t expected_raw[sizeof(example_raw)];
  memcpy(expected_raw, example_raw, sizeof(example_raw));
  expected_raw[10] = 0;
  expected_raw[11] = 0;
  auto expected_chunk = Chunk(expected_raw, sizeof(expected_raw));

  char buff[sizeof(example_raw)];
  auto buff_chunk = Chunk(buff, sizeof(buff));
  auto ctx = Context{src, dst, 1, ip::Proto::udp, 0, 0, nullptr};
  auto state = State();
  state.init(0x82DD);

  SECTION("with content") {
    auto result = Packet::send(state, ctx, buff_chunk, payload_chunk);
    REQUIRE(result);
    auto res_buff = result.consumed();
    CHECK(static_cast<const void *>(res_buff.data()) ==
          static_cast<const void *>(buff));
    CHECK(res_buff.size() == expected_chunk.size());
    CHECK(res_buff == expected_chunk);
  }

  SECTION("without content") {
    auto result = Packet::send<false>(state, ctx, buff_chunk, payload_chunk);
    REQUIRE(result);
    auto res_buff = result.consumed();
    CHECK(static_cast<const void *>(res_buff.data()) ==
          static_cast<const void *>(buff));
    CHECK(res_buff.size() == header_sz);
    CHECK(res_buff == Chunk(expected_raw, header_sz));
  }
}

/* Parsing the reference packet: every header field must match the bytes of
 * example_raw, and the payload must equal `payload`. */
TEST_CASE("recv/success", "[ip]") {
  auto state = State();
  auto result = state.recv(example_chunk, nullptr);
  // Every check below reads the parsed header, which is meaningless if
  // parsing failed, so abort here instead of reporting a cascade.
  REQUIRE(result);
  CHECK(result.consumed() == example_chunk);
  CHECK(result.state() == &state);

  auto &packet = state.get_parsed();
  CHECK(packet.ver == 4);
  CHECK(packet.header_length == 5);
  CHECK(packet.tos == 0);
  CHECK(packet.length == example_chunk.size());
  CHECK(packet.id == 0x82DE);
  CHECK(packet.flag == 0);
  CHECK(packet.frag_offset == 0);
  CHECK(packet.ttl == 1);
  CHECK(packet.proto == static_cast<uint8_t>(ip::Proto::udp));
  CHECK(packet.checksum == 0x2C60);
  CHECK(packet.source == src);
  CHECK(packet.destination == dst);
  // IHL is 5, so there are no options.
  CHECK(packet.options == Chunk(example_raw, 0));
  CHECK(packet.payload == payload_chunk);
  // The parsed packet points back into the caller's buffer.
  CHECK(packet.packet_start == example_chunk.data());
}

/* Truncated input: recv() succeeds without consuming anything and reports,
 * via demand(), how many more bytes it needs. */
TEST_CASE("recv/incomplete", "[ip]") {
  // 20: the option-less IPv4 header at the front of example_raw.
  constexpr auto header_sz = sizeof(example_raw) - sizeof(payload);
  auto nothing_consumed = Chunk(example_raw, 0);
  auto state = State();

  SECTION("less then header sz") {
    // Two bytes short of a complete header.
    auto truncated = Chunk(example_raw, header_sz - 2);
    auto outcome = state.recv(truncated, nullptr);
    CHECK(outcome);
    CHECK(outcome.consumed() == nothing_consumed);
    CHECK(outcome.demand() == 2);
  }

  SECTION("incomplete payload") {
    // Header is complete, but the last payload byte is missing.
    auto truncated = Chunk(example_raw, example_chunk.size() - 1);
    auto outcome = state.recv(truncated, nullptr);
    CHECK(outcome);
    CHECK(outcome.consumed() == nothing_consumed);
    CHECK(outcome.demand() == 1);
  }
}

/* Header with options: IHL 6 (24 bytes) = 20-byte fixed header + 4 option
 * bytes, total length 0x0018 = 24, so there is no payload. */
TEST_CASE("recv/options", "[ip]") {
  // Named distinctly so the global example_raw / example_chunk fixtures are
  // not shadowed.
  uint8_t raw_with_opts[]{
      0x46, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x00, 0x01, 0x11, 0x84, 0x8c,
      0x0a, 0x00, 0x1f, 0x7c, 0xe0, 0x00, 0x00, 0xfb, 0x14, 0xe9, 0x14, 0xe9,
  };
  auto opts_chunk = Chunk(raw_with_opts, sizeof(raw_with_opts));
  auto state = State();

  SECTION("success") {
    auto result = state.recv(opts_chunk, nullptr);
    // The checks below read the parsed header; stop if parsing failed.
    REQUIRE(result);
    CHECK(result.consumed() == opts_chunk);
    auto &packet = state.get_parsed();
    CHECK(packet.ver == 4);
    CHECK(packet.header_length == 6);
    // Options are the 4 bytes after the fixed 20-byte header.
    CHECK(packet.options == Chunk(raw_with_opts + 20, 4));
  }

  SECTION("incomplete") {
    // Only 21 of the 24 header bytes are available: 3 more are demanded.
    auto result = state.recv(Chunk(raw_with_opts, 21), nullptr);
    CHECK(result);
    CHECK(result.consumed() == Chunk(raw_with_opts, 0));
    CHECK(result.demand() == 3);
  }
}

/* Malformed input rejected by a State that validates the header checksum.
 * Catch2 re-runs the code above a SECTION for each section, so every section
 * starts from a fresh, unmodified copy of example_raw. */
TEST_CASE("recv/failures", "[ip]") {
  using State = ip::State<void, ip::Opts::validate_checksum>;

  auto state = State();
  uint8_t raw[sizeof(example_raw)];
  memcpy(raw, example_raw, sizeof(raw));
  auto raw_chunk = Chunk(raw, sizeof(raw));

  SECTION("wrong checksum") {
    // Bytes 10..11 hold the header checksum (0x2C60); bump its low byte.
    ++raw[11];
    auto result = state.recv(raw_chunk, nullptr);
    CHECK(!result);
    CHECK(result.state() == &state);
    CHECK(result.error_code() ==
          static_cast<uint32_t>(Errc::ip_checksum_mismatch));
    CHECK(result.consumed() == raw_chunk);

    SECTION("ignore checksum") {
      // Without validate_checksum the same corrupted packet is accepted.
      auto lenient_state = ip::State<void>();
      auto lenient_result = lenient_state.recv(raw_chunk, nullptr);
      CHECK(lenient_result);
    }
  }

  SECTION("wrong version") {
    // Version nibble 0 instead of 4.
    raw[0] = 0;
    auto result = state.recv(raw_chunk, nullptr);
    CHECK(!result);
    CHECK(result.state() == &state);
    CHECK(result.error_code() ==
          static_cast<uint32_t>(Errc::ip_unsupported_version));
    CHECK(result.consumed() == Chunk(raw, 1));
  }

  SECTION("wrong header_sz") {
    // Version 4 but IHL 1, below the 5-word minimum header.
    raw[0] = 0x41;
    auto result = state.recv(raw_chunk, nullptr);
    CHECK(!result);
    CHECK(result.state() == &state);
    CHECK(result.error_code() ==
          static_cast<uint32_t>(Errc::ip_header_missized));
    CHECK(result.consumed() == Chunk(raw, 1));
  }
}

/* Round trip: a packet produced by send() with calculate_checksum must be
 * accepted by recv() with validate_checksum and yield the original payload. */
TEST_CASE("send & recv with checksum", "[ip]") {
  using O = ip::Opts;
  constexpr auto opts = O::calculate_checksum | O::validate_checksum;
  using State = ip::State<void, opts>;
  using Context = typename State::Context;
  using Packet = typename State::Packet;

  char buff[sizeof(example_raw)];
  auto buff_chunk = Chunk(buff, sizeof(buff));
  auto ctx = Context{src, dst, 1, ip::Proto::udp, 0, 0, nullptr};
  auto state = State();
  state.init(0x82DD);

  auto sent = Packet::send(state, ctx, buff_chunk, payload_chunk);
  REQUIRE(sent);
  auto wire = sent.consumed();

  // Separate variable so the recv result does not overwrite the send one.
  auto received = state.recv(wire, nullptr);
  // The parsed packet is inspected below; stop if the round trip failed.
  REQUIRE(received);
  CHECK(received.state() == &state);
  CHECK(!received.demand());
  // The whole serialised packet is consumed, and the payload survives intact.
  CHECK(received.consumed() == wire);
  CHECK(state.get_parsed().payload == payload_chunk);
}
